# -*- coding: utf-8 -*-
'''
Created on 2016年12月26日

@author: ZhuJiahui
'''

import os
import time
import numpy as np
from ete3 import Tree
from topic_utils.distribution_util import merge_topic_vsm, merge_rfc,\
    merge_center_topic_by_support, get_topic_rfc
from metric_utils.kl_divergence import symmetric_kld
from mv_brt.marginal_distribution import dcm_margin2


class BayesianRoseTree(object):
    '''
    Bayesian Rose Tree (BRT) over LDA topics, enhanced with
    topic-distribution similarity.

    Starting from one single-leaf tree per topic, the forest is merged
    greedily using three operations (join / absorb / collapse).  Each
    candidate merge is scored by a likelihood that mixes the marginal
    probability of the merged document set with a KL-based similarity of
    the two centre topic distributions.
    '''

    def __init__(self):
        print("贝叶斯玫瑰树初始化")
        # Placeholder; construct_brt() replaces it with the final tree.
        self.tree_instance = Tree()

    def construct_brt(self, topic_word_distribution, related_feed_collection, feed_num):
        '''
        Build the Bayesian Rose Tree and store it in ``self.tree_instance``.

        :param topic_word_distribution: topic-word distribution vectors, one per topic
        :param related_feed_collection: per-topic collections of related document ids
        :param feed_num: total number of documents
        :return: None (the resulting ete3.Tree is stored in self.tree_instance)
        '''

        topic_num = len(topic_word_distribution)  # number of topics
        pai_gamma = 0.2  # partition-granularity control parameter

        # Initial generation probability of each topic: the fraction of all
        # documents related to it.  np.true_divide guards against integer
        # division under Python 2.
        init_probability = np.zeros(topic_num)
        for i in range(topic_num):
            init_probability[i] = np.true_divide(len(related_feed_collection[i]), feed_num)

        brt_list = []  # current forest of trees
        rest_tree_num = topic_num  # number of trees left to merge

        # Initialization: each topic becomes its own ete3 tree whose root
        # holds one leaf; root and leaf carry the same feature set.
        for i in range(topic_num):
            brt_i = Tree()  # root of tree i
            brt_i.add_features(topic_vsm=topic_word_distribution[i])
            brt_i.add_features(likelihood=init_probability[i])
            brt_i.add_features(rfc=related_feed_collection[i])
            brt_i.add_features(center_topic=topic_word_distribution[i])
            brt_i.add_features(intersect_rfc=related_feed_collection[i])

            # The topic itself is attached as a leaf named by its id.
            leaf_node = brt_i.add_child(name=str(i))
            leaf_node.add_features(topic_vsm=topic_word_distribution[i])
            leaf_node.add_features(likelihood=init_probability[i])
            leaf_node.add_features(rfc=related_feed_collection[i])
            leaf_node.add_features(center_topic=topic_word_distribution[i])
            leaf_node.add_features(intersect_rfc=related_feed_collection[i])

            brt_list.append(brt_i)

        while rest_tree_num > 1:
            # Pairwise merge likelihoods, one matrix per operation
            # (1 = join, 2 = absorb, 3 = collapse).
            similarity_matrix1 = np.zeros((rest_tree_num, rest_tree_num))
            similarity_matrix2 = np.zeros((rest_tree_num, rest_tree_num))
            similarity_matrix3 = np.zeros((rest_tree_num, rest_tree_num))

            # To shrink the search space, once the forest is larger than NN
            # only each tree's NN nearest neighbours (by symmetric KL
            # distance between centre topics) are scored.
            NN = 5  # number of candidate nearest neighbours
            if rest_tree_num > NN:
                # Symmetric KL distance matrix between centre topics,
                # used for the nearest-neighbour lookup.
                topic_kld_matrix = np.zeros((rest_tree_num, rest_tree_num))
                for i in range(rest_tree_num):
                    for j in range(i, rest_tree_num):
                        if (i == j):
                            # A huge self-distance keeps a tree from being
                            # selected as its own nearest neighbour.
                            topic_kld_matrix[i, j] = np.power(2.0, 50)
                        else:
                            topic_kld_matrix[i, j] = symmetric_kld(brt_list[i].center_topic, brt_list[j].center_topic)
                            topic_kld_matrix[j, i] = topic_kld_matrix[i, j]

                for i in range(rest_tree_num):
                    NN_set = np.argsort(topic_kld_matrix[i])[:NN]
                    for each in NN_set:
                        # Similarity derived from the symmetric KL
                        # divergence between the two centre topics.
                        topic_sim = np.true_divide(1.0, (1.0 + symmetric_kld(brt_list[i].center_topic, brt_list[each].center_topic)))
                        # Marginal probability approximated by the share of
                        # documents the two trees have in common (replaces
                        # the DCM marginal dcm_margin2 used previously).
                        marginal_probability = np.true_divide(
                            len(set(brt_list[i].intersect_rfc).intersection(set(brt_list[each].intersect_rfc))),
                            feed_num)
                        similarity_matrix1[i, each] = self.compute_join_likelihood(brt_list[i], brt_list[each], marginal_probability, topic_sim, pai_gamma)
                        similarity_matrix2[i, each] = self.compute_absorb_likelihood(brt_list[i], brt_list[each], marginal_probability, topic_sim, pai_gamma)
                        similarity_matrix3[i, each] = self.compute_collapse_likelihood(brt_list[i], brt_list[each], marginal_probability, topic_sim, pai_gamma)

                        # join and collapse are symmetric; absorb is not
                        # (left absorbs right), so it is recomputed with the
                        # arguments swapped.
                        similarity_matrix1[each, i] = similarity_matrix1[i, each]
                        similarity_matrix2[each, i] = self.compute_absorb_likelihood(brt_list[each], brt_list[i], marginal_probability, topic_sim, pai_gamma)
                        similarity_matrix3[each, i] = similarity_matrix3[i, each]

            else:
                # Small forest: score every pair exhaustively.
                for i in range(rest_tree_num):
                    for j in range(i, rest_tree_num):
                        if (i == j):
                            similarity_matrix1[i, j] = 0.0
                            similarity_matrix2[i, j] = 0.0
                            similarity_matrix3[i, j] = 0.0
                        else:
                            topic_sim = np.true_divide(1.0, (1.0 + symmetric_kld(brt_list[i].center_topic, brt_list[j].center_topic)))
                            # Document-overlap marginal probability (see the
                            # nearest-neighbour branch above).
                            marginal_probability = np.true_divide(
                                len(set(brt_list[i].intersect_rfc).intersection(set(brt_list[j].intersect_rfc))),
                                feed_num)
                            similarity_matrix1[i, j] = self.compute_join_likelihood(brt_list[i], brt_list[j], marginal_probability, topic_sim, pai_gamma)
                            similarity_matrix2[i, j] = self.compute_absorb_likelihood(brt_list[i], brt_list[j], marginal_probability, topic_sim, pai_gamma)
                            similarity_matrix3[i, j] = self.compute_collapse_likelihood(brt_list[i], brt_list[j], marginal_probability, topic_sim, pai_gamma)

                            similarity_matrix1[j, i] = similarity_matrix1[i, j]
                            similarity_matrix2[j, i] = self.compute_absorb_likelihood(brt_list[j], brt_list[i], marginal_probability, topic_sim, pai_gamma)
                            similarity_matrix3[j, i] = similarity_matrix3[i, j]

            # Pick the operation whose best pair has the highest likelihood.
            max_list = np.zeros(3)
            max_list[0] = np.max(similarity_matrix1)
            max_list[1] = np.max(similarity_matrix2)
            max_list[2] = np.max(similarity_matrix3)

            max_likelihood = np.max(max_list)
            max_list_index = np.argmax(max_list)

            if max_list_index == 0:
                self.brt_join(brt_list, similarity_matrix1, rest_tree_num, max_likelihood)
                print("Join")
            elif max_list_index == 1:
                self.brt_absorb(brt_list, similarity_matrix2, rest_tree_num, max_likelihood)
                print("Absorb")
            elif max_list_index == 2:
                self.brt_collapse(brt_list, similarity_matrix3, rest_tree_num, max_likelihood)
                print("Collapse")
            else:
                # Defensive: np.argmax over three entries cannot reach here.
                print("Error!")
                break

            # Every operation merges exactly two trees into one.
            rest_tree_num = rest_tree_num - 1

        self.tree_instance = brt_list[0]

    def compute_join_likelihood(self, left_tree, right_tree, marginal_probability, topic_sim, pai_gamma):
        '''
        Generation likelihood of a BRT join (new parent over both trees).

        :param left_tree: left subtree (must expose ``likelihood``)
        :param right_tree: right subtree (must expose ``likelihood``)
        :param marginal_probability: marginal probability of the merged topic set
        :param topic_sim: topic-distribution similarity of the pair
        :param pai_gamma: partition-granularity control parameter
        :return: the likelihood of performing this join
        '''

        nTm = 2  # a join always produces a parent with two children
        pai_Tm = 1 - np.power((1 - pai_gamma), (nTm - 1))
        pZm_Tm = (pai_Tm * marginal_probability + (1 - pai_Tm) * left_tree.likelihood * right_tree.likelihood) * topic_sim

        return pZm_Tm

    def compute_absorb_likelihood(self, left_tree, right_tree, marginal_probability, topic_sim, pai_gamma):
        '''
        Generation likelihood of a BRT absorb (left absorbs right as a child).

        :param left_tree: absorbing subtree (its ``children`` likelihoods are multiplied)
        :param right_tree: absorbed subtree (must expose ``likelihood``)
        :param marginal_probability: marginal probability of the merged topic set
        :param topic_sim: topic-distribution similarity of the pair
        :param pai_gamma: partition-granularity control parameter
        :return: the likelihood of performing this absorb
        '''
        nTm = len(left_tree.children) + 1  # left's children plus the absorbed tree
        pai_Tm = 1 - np.power((1 - pai_gamma), (nTm - 1))
        temp_product = 1.0
        for each in left_tree.children:
            temp_product = temp_product * each.likelihood

        pZm_Tm = (pai_Tm * marginal_probability + (1 - pai_Tm) * right_tree.likelihood * temp_product) * topic_sim

        return pZm_Tm

    def compute_collapse_likelihood(self, left_tree, right_tree, marginal_probability, topic_sim, pai_gamma):
        '''
        Generation likelihood of a BRT collapse (children of both trees merged
        under one root).

        :param left_tree: left subtree (its ``children`` likelihoods are multiplied)
        :param right_tree: right subtree (its ``children`` likelihoods are multiplied)
        :param marginal_probability: marginal probability of the merged topic set
        :param topic_sim: topic-distribution similarity of the pair
        :param pai_gamma: partition-granularity control parameter
        :return: the likelihood of performing this collapse
        '''

        nTm = len(left_tree.children) + len(right_tree.children)  # all children end up under one root
        pai_Tm = 1 - np.power((1 - pai_gamma), (nTm - 1))

        temp_product = 1.0
        for each in left_tree.children:
            temp_product = temp_product * each.likelihood
        for each2 in right_tree.children:
            temp_product = temp_product * each2.likelihood

        pZm_Tm = (pai_Tm * marginal_probability + (1 - pai_Tm) * temp_product) * topic_sim

        return pZm_Tm

    def _set_merged_features(self, node, merged_vsm, max_likelihood, merged_rfc,
                             new_center_topic, new_intersect_rfc):
        '''Overwrite a tree node's merge-related features in place.'''
        node.topic_vsm = merged_vsm
        node.likelihood = max_likelihood
        node.rfc = merged_rfc
        node.center_topic = new_center_topic
        node.intersect_rfc = new_intersect_rfc

    def prepare_operation(self, brt_list, similarity_matrix, rest_tree_num):
        '''
        Common preparation for all three merge operations: locate the
        best-scoring pair and pre-compute the merged features.

        :param brt_list: forest of trees
        :param similarity_matrix: likelihood matrix for the chosen operation
        :param rest_tree_num: current number of trees in the forest
        :return: (max_row_index, max_col_index, left_children_num,
                  right_children_num, merged_vsm, merged_rfc,
                  new_center_topic, new_intersect_rfc)
        '''

        # Position of the matrix maximum, converted from the flat argmax
        # index to (row, col) coordinates.
        max_tree_index1 = np.argmax(similarity_matrix)
        max_row_index, max_col_index = np.unravel_index(max_tree_index1, (rest_tree_num, rest_tree_num))
        max_row_index = int(max_row_index)
        max_col_index = int(max_col_index)

        # Child counts of the two trees decide which merge shape applies.
        left_children_num = len(brt_list[max_row_index].children)
        right_children_num = len(brt_list[max_col_index].children)

        # Merge vector spaces, document collections and centre topics.
        merged_vsm = merge_topic_vsm(brt_list[max_row_index].topic_vsm, brt_list[max_col_index].topic_vsm)
        merged_rfc = merge_rfc(brt_list[max_row_index].rfc, brt_list[max_col_index].rfc)
        new_center_topic = merge_center_topic_by_support(brt_list[max_row_index].center_topic, left_children_num, brt_list[max_col_index].center_topic, right_children_num)
        new_intersect_rfc = list(set(brt_list[max_row_index].intersect_rfc).intersection(brt_list[max_col_index].intersect_rfc))

        return max_row_index, max_col_index, left_children_num, right_children_num, \
            merged_vsm, merged_rfc, new_center_topic, new_intersect_rfc

    def brt_join(self, brt_list, similarity_matrix1, rest_tree_num, max_likelihood):
        '''
        Perform the BRT join operation on the best-scoring pair.

        :param brt_list: forest of trees (mutated in place)
        :param similarity_matrix1: join likelihood matrix
        :param rest_tree_num: current number of trees in the forest
        :param max_likelihood: likelihood value of the chosen merge
        :return: None
        '''

        max_row_index, max_col_index, left_children_num, right_children_num, \
            merged_vsm, merged_rfc, new_center_topic, new_intersect_rfc \
            = self.prepare_operation(brt_list, similarity_matrix1, rest_tree_num)

        # ``del brt_list[idx]`` removes exactly the indexed tree (list.remove
        # would search by equality instead).
        if left_children_num == 1 and right_children_num == 1:
            # Two single-leaf trees: move the right leaf under the left root.
            brt_list[max_row_index].add_child(brt_list[max_col_index].children[0])
            self._set_merged_features(brt_list[max_row_index], merged_vsm, max_likelihood,
                                      merged_rfc, new_center_topic, new_intersect_rfc)
            del brt_list[max_col_index]
        elif left_children_num == 1 and right_children_num > 1:
            # Attach the right tree below the single-leaf left root.
            brt_list[max_row_index].add_child(brt_list[max_col_index])
            self._set_merged_features(brt_list[max_row_index], merged_vsm, max_likelihood,
                                      merged_rfc, new_center_topic, new_intersect_rfc)
            del brt_list[max_col_index]
        elif left_children_num > 1 and right_children_num == 1:
            # Attach the left tree below the single-leaf right root.
            brt_list[max_col_index].add_child(brt_list[max_row_index])
            self._set_merged_features(brt_list[max_col_index], merged_vsm, max_likelihood,
                                      merged_rfc, new_center_topic, new_intersect_rfc)
            del brt_list[max_row_index]
        else:
            # Both trees have several children: introduce a new parent node.
            brt_merge = Tree()
            brt_merge.add_features(topic_vsm=merged_vsm)
            brt_merge.add_features(likelihood=max_likelihood)
            brt_merge.add_features(rfc=merged_rfc)
            brt_merge.add_features(center_topic=new_center_topic)
            brt_merge.add_features(intersect_rfc=new_intersect_rfc)

            brt_merge.add_child(brt_list[max_row_index])
            brt_merge.add_child(brt_list[max_col_index])

            # Keep the forest consistent: merged tree replaces the left one.
            brt_list[max_row_index] = brt_merge
            del brt_list[max_col_index]

    def brt_absorb(self, brt_list, similarity_matrix2, rest_tree_num, max_likelihood):
        '''
        Perform the BRT absorb operation (left absorbs right).

        :param brt_list: forest of trees (mutated in place)
        :param similarity_matrix2: absorb likelihood matrix
        :param rest_tree_num: current number of trees in the forest
        :param max_likelihood: likelihood value of the chosen merge
        :return: None
        '''

        max_row_index, max_col_index, left_children_num, right_children_num, \
            merged_vsm, merged_rfc, new_center_topic, new_intersect_rfc \
            = self.prepare_operation(brt_list, similarity_matrix2, rest_tree_num)

        if left_children_num >= 1 and right_children_num == 1:
            # Right tree is a single leaf: degenerates to joining that leaf.
            brt_list[max_row_index].add_child(brt_list[max_col_index].children[0])
        else:
            brt_list[max_row_index].add_child(brt_list[max_col_index])

        self._set_merged_features(brt_list[max_row_index], merged_vsm, max_likelihood,
                                  merged_rfc, new_center_topic, new_intersect_rfc)
        del brt_list[max_col_index]

    def brt_collapse(self, brt_list, similarity_matrix3, rest_tree_num, max_likelihood):
        '''
        Perform the BRT collapse operation: all children of the right tree
        are re-attached directly under the left tree's root.

        :param brt_list: forest of trees (mutated in place)
        :param similarity_matrix3: collapse likelihood matrix
        :param rest_tree_num: current number of trees in the forest
        :param max_likelihood: likelihood value of the chosen merge
        :return: None
        '''

        max_row_index, max_col_index, _, _, \
            merged_vsm, merged_rfc, new_center_topic, new_intersect_rfc \
            = self.prepare_operation(brt_list, similarity_matrix3, rest_tree_num)

        right_children_list = brt_list[max_col_index].get_children()
        for each in right_children_list:
            brt_list[max_row_index].add_child(each)

        self._set_merged_features(brt_list[max_row_index], merged_vsm, max_likelihood,
                                  merged_rfc, new_center_topic, new_intersect_rfc)
        del brt_list[max_col_index]


def test_brt1():
    '''
    Smoke test: load LDA document-topic and topic-word distributions from
    disk, build a Bayesian Rose Tree and print it in ASCII form.

    Reads snapshot 432 from ``../dataset2/LDA/feed_topic20`` and
    ``../dataset2/LDA/topic_word20`` relative to the parent of the current
    working directory.
    '''
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # recommended wall-clock replacement for interval timing.
    start = time.perf_counter()
    now_directory = os.getcwd()
    root_directory = os.path.dirname(now_directory) + '/'
    read_directory1 = root_directory + u'dataset2/LDA/feed_topic20'
    read_directory2 = root_directory + u'dataset2/LDA/topic_word20'

    # Document-topic and topic-word distribution matrices for snapshot 432.
    dt_vsm = np.loadtxt(read_directory1 + '/' + str(432) + '.txt')
    topic_word_distribution = np.loadtxt(read_directory2 + '/' + str(432) + '.txt')

    # Per-topic related-document collections (threshold 0.0 keeps every
    # document with any topic weight — TODO confirm against get_topic_rfc).
    related_feed_collection = get_topic_rfc(dt_vsm, 0.0)
    feed_num = len(dt_vsm)

    this_brt = BayesianRoseTree()
    this_brt.construct_brt(topic_word_distribution, related_feed_collection, feed_num)

    print(this_brt.tree_instance.get_ascii(show_internal=True))
    print('Total time %f seconds' % (time.perf_counter() - start))


if __name__ == '__main__':
    test_brt1()
